// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Reflection;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;
using WasmAppBuilder;
using JoinedString;

internal sealed class IcallTableGenerator
{
    public string[]? Cookies { get; private set; }

    private List<Icall> _icalls = new List<Icall>();
    private readonly HashSet<string> _signatures = new();
    private Dictionary<string, IcallClass> _runtimeIcalls = new Dictionary<string, IcallClass>();

    private LogAdapter Log { get; set; }
    private readonly Func<string, string> _fixupSymbolName;

    //
    // Given the runtime generated icall table, and a set of assemblies, generate
    // a smaller linked icall table mapping tokens to C function names
    // The runtime icall table should be generated using
    // mono --print-icall-table
    //
    public IcallTableGenerator(string? runtimeIcallTableFile, Func<string, string> fixupSymbolName, LogAdapter log)
    {
        Log = log;
        _fixupSymbolName = fixupSymbolName;
        if (runtimeIcallTableFile != null)
            ReadTable(runtimeIcallTableFile);
    }

    public void ScanAssembly(Assembly asm)
    {
        foreach (Type type in asm.GetTypes())
            ProcessType(type);
    }

    public IEnumerable<string> Generate(string? outputPath)
    {
        if (outputPath != null)
        {
            using TempFileName tmpFileName = new();
            using (var w = new JoinedStringStreamWriter(tmpFileName.Path, false))
                EmitTable(w);

            if (Utils.CopyIfDifferent(tmpFileName.Path, outputPath, useHash: false))
                Log.LogMessage(MessageImportance.Low, $"Generating icall table to '{outputPath}'.");
            else
                Log.LogMessage(MessageImportance.Low, $"Icall table in {outputPath} is unchanged.");
        }

        return _signatures;
    }

    private void EmitTable(StreamWriter w)
    {
        var assemblyMap = new Dictionary<string, string>();
        foreach (var icall in _icalls)
            assemblyMap[icall.Assembly!] = icall.Assembly!;

        foreach (var assembly in assemblyMap.Keys)
        {
            var sorted = _icalls.Where(i => i.Assembly == assembly).ToArray();
            Array.Sort(sorted);

            string aname = assembly == "System.Private.CoreLib" ? "corlib" : _fixupSymbolName(assembly);

            w.Write(
                $$"""

                #define ICALL_TABLE_{{aname}} 1

                static int {{aname}}_icall_indexes [] = {
                    {{sorted.Join($",{w.NewLine}    ", (icall, i) => $"/* {i} */ {icall.TokenIndex}")}}
                };

                {{sorted.Join($" {w.NewLine}", GenIcallDecl)}}

                static void *{{aname}}_icall_funcs [] = {
                    {{sorted.Join($",{w.NewLine}    ", (icall, i) => $"/* {i}:{icall.TokenIndex} */ {icall.Func}" )}}
                };

                static uint8_t {{aname}}_icall_flags [] = {
                    {{sorted.Join($",{w.NewLine}    ", (icall, i) => $"/* {i}:{icall.TokenIndex} */ {icall.Flags}")}}
                };

                """);
        }
    }

    // Read the icall table generated by mono --print-icall-table
    private void ReadTable(string filename)
    {
        using var stream = File.OpenRead(filename);
        using JsonDocument json = JsonDocument.Parse(stream);

        var arr = json.RootElement;
        foreach (var v in arr.EnumerateArray())
        {
            var className = v.GetProperty("klass").GetString()!;
            if (className == "")
                // Dummy value
                continue;

            var icallClass = new IcallClass(className);
            _runtimeIcalls[icallClass.Name] = icallClass;
            foreach (var icall_j in v.GetProperty("icalls").EnumerateArray())
            {
                if (!icall_j.TryGetProperty("name", out var nameElem))
                    continue;

                string name = nameElem.GetString()!;
                string func = icall_j.GetProperty("func").GetString()!;
                bool handles = icall_j.GetProperty("handles").GetBoolean();
                int flags = icall_j.TryGetProperty ("flags", out var _) ? int.Parse (icall_j.GetProperty("flags").GetString()!) : 0;

                icallClass.Icalls.Add(name, new Icall(name, func, handles, flags));
            }
        }
    }

    private void ProcessType(Type type)
    {
        foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
        {
            if ((method.GetMethodImplementationFlags() & MethodImplAttributes.InternalCall) == 0)
                continue;

            try
            {
                AddSignature(type, method);
            }
            catch (Exception ex) when (ex is not LogAsErrorException)
            {
                Log.Warning("WASM0001", $"Could not get icall, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'");
                continue;
            }

            var className = method.DeclaringType!.FullName!;
            if (!_runtimeIcalls.ContainsKey(className))
                // Registered at runtime
                continue;

            var icallClass = _runtimeIcalls[className];

            Icall? icall = null;

            // Try name first
            icallClass.Icalls.TryGetValue(method.Name, out icall);
            if (icall == null)
            {
                string? methodSig = BuildSignature(method, className);
                if (methodSig != null)
                    icallClass.Icalls.TryGetValue(methodSig, out icall);

                if (icall == null)
                    // Registered at runtime
                    continue;
            }

            icall.Method = method;
            icall.TokenIndex = (int)method.MetadataToken & 0xffffff;
            icall.Assembly = method.DeclaringType.Module.Assembly.GetName().Name;
            _icalls.Add(icall);
        }

        foreach (var nestedType in type.GetNestedTypes())
            ProcessType(nestedType);

        string? BuildSignature(MethodInfo method, string className)
        {
            // Then with signature
            var sig = new StringBuilder(method.Name + "(");
            int pindex = 0;
            foreach (var par in method.GetParameters())
            {
                if (pindex > 0)
                    sig.Append(',');

                var t = par.ParameterType;
                try
                {
                    AppendType(sig, t);
                }
                catch (NotImplementedException nie)
                {
                    Log.Warning("WASM0001", $"Failed to generate icall function for method '[{method.DeclaringType!.Assembly.GetName().Name}] {className}::{method.Name}'" +
                                    $" because type '{nie.Message}' is not supported for parameter named '{par.Name}'. Ignoring.");
                    return null;
                }
                pindex++;
            }
            sig.Append(')');

            return sig.ToString();
        }

        void AddSignature(Type type, MethodInfo method)
        {
            string? signature = SignatureMapper.MethodToSignature(method, Log);
            if (signature == null)
            {
                throw new LogAsErrorException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
            }

            if (_signatures.Add(signature))
                Log.LogMessage(MessageImportance.Low, $"Adding icall signature {signature} for method '{type.FullName}.{method.Name}'");
        }
    }

    // Append the type name used by the runtime icall tables
    private void AppendType(StringBuilder sb, Type t)
    {
        if (t.IsArray)
        {
            AppendType(sb, t.GetElementType()!);
            sb.Append("[]");
        }
        else if (t.IsByRef)
        {
            AppendType(sb, t.GetElementType()!);
            sb.Append('&');
        }
        else if (t.IsPointer)
        {
            AppendType(sb, t.GetElementType()!);
            sb.Append('*');
        }
        else if (t.IsEnum)
        {
            AppendType(sb, Enum.GetUnderlyingType(t));
        }
        else
        {
            sb.Append(t.Name switch
            {
                nameof(Char) => "char",
                nameof(Boolean) => "bool",
                nameof(SByte) => "sbyte",
                nameof(Byte) => "byte",
                nameof(Int16) => "int16",
                nameof(UInt16) => "uint16",
                nameof(Int32) => "int",
                nameof(UInt32) => "uint",
                nameof(Int64) => "long",
                nameof(UInt64) => "ulong",
                nameof(IntPtr) => "intptr",
                nameof(UIntPtr) => "uintptr",
                nameof(Single) => "single",
                nameof(Double) => "double",
                nameof(Object) => "object",
                nameof(String) => "string",
                _ => throw new NotImplementedException(t.FullName)
            });
        }
    }

    private static string MapType(Type t) => t.Name switch
    {
        "Void" => "void",
        nameof(Double) => "double",
        nameof(Single) => "float",
        nameof(Int64) => "int64_t",
        nameof(UInt64) => "uint64_t",
        _ => "int",
    };


    private static string GenIcallDecl(Icall icall)
    {
        List<string> args = new();
        var method = icall.Method!;
        if (!method.IsStatic)
            args.Add("int");
        args.AddRange(method.GetParameters().Select(p => MapType(p.ParameterType)));
        if (icall.Handles)
            args.Add("int");

        return $"{MapType(method.ReturnType)} {icall.Func} ({args.Join(", ")});";
    }

    private sealed class Icall : IComparable<Icall>
    {
        public Icall(string name, string func, bool handles, int flags)
        {
            Name = name;
            Func = func;
            Flags = flags;
            Handles = handles;
            TokenIndex = 0;
        }

        public string Name;
        public string Func;
        public string? Assembly;
        public bool Handles;
        public int Flags;
        public int TokenIndex;
        public MethodInfo? Method;

        public int CompareTo(Icall? other)
        {
            return TokenIndex - other!.TokenIndex;
        }
    }

    private sealed class IcallClass
    {
        public IcallClass(string name)
        {
            Name = name;
            Icalls = new Dictionary<string, Icall>();
        }

        public string Name;
        public Dictionary<string, Icall> Icalls;
    }
}
